import json
import pickle
from tqdm import  tqdm
import os


# Directories containing the separately-sampled query pickles to merge.
dir_list = ["line_predictor_less_query_split"]
# Directory the merged pickle(s) are written to (created if missing).
output_dir = "line_predictor_less_query_merge/"
# Dataset names; each one produces "<name>_line_predictor_test_queries.pkl".
data_names = ["FB15K"]


def merge_query_file(query_file_dict_list):
    """Merge several query dictionaries into one ``{query_type: {query: answers}}`` dict.

    The query file list is a list of dictionaries of the train/validation/test
    queries that were sampled separately. Each dictionary maps a query type to
    a ``{query: answer_dict}`` mapping. Query types are unioned. When the same
    query appears under the same query type in more than one input, the entry
    from the dictionary that comes later in the list wins.

    Prints the number of queries per query type as a summary.

    Args:
        query_file_dict_list: iterable of ``{query_type: {query: answer_dict}}``
            dictionaries. The inputs themselves are not modified.

    Returns:
        A new dict holding the merged queries.
    """
    merged_dict = {}
    for query_file_dict in query_file_dict_list:
        for query_type, queries in query_file_dict.items():
            # Create the per-type bucket on first sight, then overlay this
            # file's queries (later files overwrite duplicate queries).
            merged_dict.setdefault(query_type, {}).update(queries)

    print({k: len(v) for k, v in merged_dict.items()})
    return merged_dict


def _list_sample_files(directory_name):
    """Return the sorted paths of the regular files directly inside directory_name."""
    # os.listdir order is arbitrary; since later files win on duplicate
    # queries, sorting makes the merged output reproducible.
    return [
        os.path.join(directory_name, file_name)
        for file_name in sorted(os.listdir(directory_name))
        if os.path.isfile(os.path.join(directory_name, file_name))
    ]


def _load_query_files(file_paths):
    """Unpickle every file in file_paths and return the loaded objects in order."""
    loaded = []
    for path in tqdm(file_paths):
        # NOTE: pickle.load can execute arbitrary code; only run this script
        # on query files you generated yourself.
        with open(path, "rb") as fin:
            loaded.append(pickle.load(fin))
    return loaded


def main():
    """Merge every query pickle in each directory of dir_list into output_dir."""
    os.makedirs(output_dir, exist_ok=True)

    for directory_name in dir_list:
        file_paths = _list_sample_files(directory_name)

        for data_name in data_names:
            print(data_name)

            test_data_prefix = data_name + "_line_predictor_test_queries"

            # NOTE(review): every file in directory_name is merged regardless
            # of data_name, and the output name does not include
            # directory_name, so several entries in dir_list would overwrite
            # each other's output. Confirm before adding datasets/directories.
            merged_queries = merge_query_file(_load_query_files(file_paths))

            output_path = os.path.join(output_dir, test_data_prefix + ".pkl")
            with open(output_path, "wb") as fout:
                pickle.dump(merged_queries, fout)


if __name__ == "__main__":
    main()